{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import networkx as nx\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.utils import to_networkx\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert(data):\n",
    "    G = to_networkx(data)\n",
    "    G.graph['label'] = data.y.item()\n",
    "    return nx.to_undirected(G)\n",
    "\n",
    "\n",
    "def dataset_to_graphs(dataset):\n",
    "    graphs = []\n",
    "    for data in dataset:\n",
    "        graphs.append(convert(data))\n",
    "    return graphs\n",
    "    \n",
    "\n",
    "def check(graphs):\n",
    "    num_iso_pairs = 0\n",
    "    num_inconsistent_labels = 0\n",
    "    num_graphs = len(graphs)\n",
    "    combinations = itertools.combinations(range(num_graphs), 2)\n",
    "    \n",
    "    for (i1, i2) in combinations:\n",
    "        G1, G2 = graphs[i1], graphs[i2]\n",
    "        label1, label2 = G1.graph['label'], G2.graph['label']\n",
    "        \n",
    "        if nx.is_isomorphic(G1, G2):\n",
    "            num_iso_pairs += 1\n",
    "            if label1 != label2:\n",
    "                num_inconsistent_labels += 1\n",
    "    \n",
    "    print(f\"number of isomorphic pairs: {num_iso_pairs}\") \n",
    "    print(f\"number of isomorphic pairs with inconsistent labels: {num_inconsistent_labels}\")\n",
    "    print(f\"ratio of inconsistently labelled isomorphic pairs vs. isomorphic pairs: {num_inconsistent_labels / num_iso_pairs:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## Checking dataset IMDB-BINARY ##############\n",
      "number of isomorphic pairs: 3356\n",
      "number of isomorphic pairs with inconsistent labels: 1119\n",
      "ratio of inconsistently labelled isomorphic pairs vs. isomorphic pairs: 0.3334\n"
     ]
    }
   ],
   "source": [
    "dataset_name = \"IMDB-BINARY\"\n",
    "\n",
    "print(f\"############## Checking dataset {dataset_name} ##############\")\n",
    "dataset = TUDataset(f'tmp/{dataset_name}', dataset_name)\n",
    "graphs = dataset_to_graphs(dataset)\n",
    "\n",
    "# WARNING: this might take several minutes depending on your hardware\n",
    "check(graphs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def bfs_seq(G, start_id):\n",
    "    \"\"\" taken from https://github.com/JiaxuanYou/graph-generation/blob/master/data.py \"\"\"\n",
    "    dictionary = dict(nx.bfs_successors(G, start_id))\n",
    "    start = [start_id]\n",
    "    output = [start_id]\n",
    "    while len(start) > 0:\n",
    "        next = []\n",
    "        while len(start) > 0:\n",
    "            current = start.pop(0)\n",
    "            neighbor = dictionary.get(current)\n",
    "            if neighbor is not None:\n",
    "                next = next + neighbor\n",
    "        output = output + next\n",
    "        start = next\n",
    "    return output\n",
    "\n",
    "# 10 and 710 have different labels, but are isomorphic\n",
    "G1, G2 = graphs[10], graphs[710]\n",
    "\n",
    "# reorder nodes\n",
    "seq1, seq2 = bfs_seq(G1, 0), bfs_seq(G2, 0)\n",
    "G2 = nx.relabel_nodes(G2, {n:m for n, m in zip(seq2, seq1)})\n",
    "print(f\"G1 label: {G1.graph['label']} - G2 label: {G2.graph['label']}\")\n",
    "\n",
    "fig, axs = plt.subplots(1, 2)\n",
    "pos = nx.random_layout(G1, seed=42)\n",
    "nx.draw_networkx(G1, pos=pos, ax=axs.flat[0])\n",
    "nx.draw_networkx(G2, pos=pos, ax=axs.flat[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
